{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T09:24:15.427545Z",
     "start_time": "2021-04-23T09:24:15.015898Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T00:51:33.640418Z",
     "iopub.status.busy": "2021-04-24T00:51:33.639880Z",
     "iopub.status.idle": "2021-04-24T00:51:38.503727Z",
     "shell.execute_reply": "2021-04-24T00:51:38.502437Z",
     "shell.execute_reply.started": "2021-04-24T00:51:33.640371Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-04-24 08:51:33--  https://mirror.coggle.club/dataset/ner/msra.zip\n",
      "正在解析主机 mirror.coggle.club (mirror.coggle.club)... 60.217.246.15\n",
      "正在连接 mirror.coggle.club (mirror.coggle.club)|60.217.246.15|:443... 已连接。\n",
      "已发出 HTTP 请求，正在等待回应... 200 OK\n",
      "长度： 7676370 (7.3M) [application/zip]\n",
      "正在保存至: “msra.zip”\n",
      "\n",
      "msra.zip            100%[===================>]   7.32M  1.92MB/s    用时 3.9s    \n",
      "\n",
      "2021-04-24 08:51:38 (1.87 MB/s) - 已保存 “msra.zip” [7676370/7676370])\n",
      "\n",
      "Archive:  msra.zip\n",
      "   creating: msra/\n",
      "   creating: msra/test/\n",
      "  inflating: msra/test/sentences.txt  \n",
      "  inflating: msra/test/tags.txt      \n",
      "  inflating: msra/msra_test_bio      \n",
      "  inflating: msra/msra_train_bio     \n",
      "   creating: msra/train/\n",
      "  inflating: msra/train/sentences.txt  \n",
      "  inflating: msra/train/tags.txt     \n",
      "  inflating: msra/tags.txt           \n",
      "   creating: msra/val/\n",
      "  inflating: msra/val/sentences.txt  \n",
      "  inflating: msra/val/tags.txt       \n"
     ]
    }
   ],
   "source": [
    "!wget https://mirror.coggle.club/dataset/ner/msra.zip\n",
    "!unzip msra.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T13:34:52.838479Z",
     "start_time": "2021-04-23T13:34:52.455079Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T12:48:16.632349Z",
     "iopub.status.busy": "2021-04-24T12:48:16.631788Z",
     "iopub.status.idle": "2021-04-24T12:48:16.721436Z",
     "shell.execute_reply": "2021-04-24T12:48:16.720983Z",
     "shell.execute_reply.started": "2021-04-24T12:48:16.632302Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import codecs\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T13:34:54.138977Z",
     "start_time": "2021-04-23T13:34:53.477451Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T12:48:17.260018Z",
     "iopub.status.busy": "2021-04-24T12:48:17.259481Z",
     "iopub.status.idle": "2021-04-24T12:48:17.662302Z",
     "shell.execute_reply": "2021-04-24T12:48:17.661766Z",
     "shell.execute_reply.started": "2021-04-24T12:48:17.259971Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "tag_type = ['O', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']\n",
    "# B-ORG I-ORG 机构的开始位置和中间位置\n",
    "# B-PER I-PER 人物名字的开始位置和中间位置\n",
    "# B-LOC I-LOC 位置的开始位置和中间位置\n",
    "\n",
    "train_lines = codecs.open('msra/train/sentences.txt').readlines()\n",
    "train_lines = [x.replace(' ', '').strip() for x in train_lines]\n",
    "\n",
    "train_tags = codecs.open('msra/train/tags.txt').readlines()\n",
    "train_tags = [x.strip().split(' ') for x in train_tags]\n",
    "train_tags = [[tag_type.index(x) for x in tag] for tag in train_tags]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-24T12:48:18.742526Z",
     "iopub.status.busy": "2021-04-24T12:48:18.741982Z",
     "iopub.status.idle": "2021-04-24T12:48:18.752852Z",
     "shell.execute_reply": "2021-04-24T12:48:18.752204Z",
     "shell.execute_reply.started": "2021-04-24T12:48:18.742480Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_lines, train_tags = train_lines[:20000], train_tags[:20000] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T14:53:47.364732Z",
     "start_time": "2021-04-23T14:53:47.359559Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T12:48:19.125637Z",
     "iopub.status.busy": "2021-04-24T12:48:19.125127Z",
     "iopub.status.idle": "2021-04-24T12:48:19.331877Z",
     "shell.execute_reply": "2021-04-24T12:48:19.330797Z",
     "shell.execute_reply.started": "2021-04-24T12:48:19.125592Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "val_lines = codecs.open('msra/val/sentences.txt').readlines()\n",
    "val_lines = [x.replace(' ', '').strip() for x in val_lines]\n",
    "\n",
    "val_tags = codecs.open('msra/val/tags.txt').readlines()\n",
    "val_tags = [x.strip().split(' ') for x in val_tags]\n",
    "val_tags = [[tag_type.index(x) for x in tag] for tag in val_tags]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T13:35:14.896288Z",
     "start_time": "2021-04-23T13:35:02.759708Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T12:48:21.871049Z",
     "iopub.status.busy": "2021-04-24T12:48:21.870502Z",
     "iopub.status.idle": "2021-04-24T12:48:31.550685Z",
     "shell.execute_reply": "2021-04-24T12:48:31.550074Z",
     "shell.execute_reply.started": "2021-04-24T12:48:21.871003Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from transformers import BertTokenizer\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "train_encoding = tokenizer(list(train_lines), truncation=True, padding=True, max_length=64)\n",
    "val_encoding = tokenizer(list(val_lines), truncation=True, padding=True, max_length=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T13:35:17.120298Z",
     "start_time": "2021-04-23T13:35:16.751314Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T12:49:16.594534Z",
     "iopub.status.busy": "2021-04-24T12:49:16.593957Z",
     "iopub.status.idle": "2021-04-24T12:49:17.018653Z",
     "shell.execute_reply": "2021-04-24T12:49:17.018088Z",
     "shell.execute_reply.started": "2021-04-24T12:49:16.594488Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "\n",
    "class TextDataset(Dataset):\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "        \n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: torch.tensor(val[idx])[:64] for key, val in self.encodings.items()}\n",
    "        # 字级别的标注\n",
    "        item['labels'] = torch.tensor([0] + self.labels[idx] + [0] * (63-len(self.labels[idx])))[:64]\n",
    "        return item\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "train_dataset = TextDataset(train_encoding, train_tags[:])\n",
    "test_dataset = TextDataset(val_encoding, val_tags[:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-24T12:49:25.212607Z",
     "iopub.status.busy": "2021-04-24T12:49:25.212040Z",
     "iopub.status.idle": "2021-04-24T12:49:25.226301Z",
     "shell.execute_reply": "2021-04-24T12:49:25.225693Z",
     "shell.execute_reply.started": "2021-04-24T12:49:25.212561Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([ 101, 1963,  862, 6237, 1104, 6639, 4413, 4518, 7270, 3309, 2100, 1762,\n",
       "         4638, 6436, 1914, 4757, 4688, 8024, 7028, 2920, 3212, 3189, 3823, 7305,\n",
       "         6639, 4413, 4638, 7413, 7599, 8024, 2768,  711, 1921, 3823, 6639, 1781,\n",
       "          677,  678, 1079, 1912, 1168, 1905, 6379, 6389, 4638, 6413, 7579,  511,\n",
       "          102,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,\n",
       "            0,    0,    0,    0]),\n",
       " 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
       " 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
       " 'labels': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-24T02:30:40.720554Z",
     "iopub.status.busy": "2021-04-24T02:30:40.720381Z",
     "iopub.status.idle": "2021-04-24T02:30:41.476863Z",
     "shell.execute_reply": "2021-04-24T02:30:41.476337Z",
     "shell.execute_reply.started": "2021-04-24T02:30:40.720540Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "for idx in range(len(train_dataset)):\n",
    "    item = train_dataset[idx]\n",
    "    for key in item:\n",
    "        if item[key].shape[0] != 64:\n",
    "            print(key, item[key].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-24T02:30:44.593987Z",
     "iopub.status.busy": "2021-04-24T02:30:44.593460Z",
     "iopub.status.idle": "2021-04-24T02:30:44.702663Z",
     "shell.execute_reply": "2021-04-24T02:30:44.702273Z",
     "shell.execute_reply.started": "2021-04-24T02:30:44.593943Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "for idx in range(len(test_dataset)):\n",
    "    item = test_dataset[idx]\n",
    "    for key in item:\n",
    "        if item[key].shape[0] != 64:\n",
    "            print(key, item[key].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import BertForTokenClassification, AdamW, get_linear_schedule_with_warmup\n",
    "model = BertForTokenClassification.from_pretrained('bert-base-chinese', num_labels=7)\n",
    "\n",
    "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T13:35:18.744186Z",
     "start_time": "2021-04-23T13:35:18.183277Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T03:13:20.803134Z",
     "iopub.status.busy": "2021-04-24T03:13:20.802591Z",
     "iopub.status.idle": "2021-04-24T03:13:28.896245Z",
     "shell.execute_reply": "2021-04-24T03:13:28.895455Z",
     "shell.execute_reply.started": "2021-04-24T03:13:20.803086Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'DataLoader' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-6b86db8eef56>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cuda:1\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mtrain_loader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m \u001b[0mtest_dataloader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'DataLoader' is not defined"
     ]
    }
   ],
   "source": [
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)\n",
    "\n",
    "optim = AdamW(model.parameters(), lr=5e-5)\n",
    "total_steps = len(train_loader) * 1\n",
    "scheduler = get_linear_schedule_with_warmup(optim, \n",
    "                                            num_warmup_steps = 0, # Default value in run_glue.py\n",
    "                                            num_training_steps = total_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-24T03:13:55.446275Z",
     "iopub.status.busy": "2021-04-24T03:13:55.445698Z",
     "iopub.status.idle": "2021-04-24T03:13:55.711722Z",
     "shell.execute_reply": "2021-04-24T03:13:55.710653Z",
     "shell.execute_reply.started": "2021-04-24T03:13:55.446227Z"
    }
   },
   "outputs": [],
   "source": [
    "model = torch.load('bert-ner.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T13:37:36.436445Z",
     "start_time": "2021-04-23T13:35:34.993720Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T02:36:53.191816Z",
     "iopub.status.busy": "2021-04-24T02:36:53.191231Z",
     "iopub.status.idle": "2021-04-24T02:45:04.763608Z",
     "shell.execute_reply": "2021-04-24T02:45:04.763157Z",
     "shell.execute_reply.started": "2021-04-24T02:36:53.191764Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------Epoch: 0 ----------------\n",
      "0.2490234375 1.7682684659957886\n",
      "0.984375 0.08177318423986435\n",
      "0.9794921875 0.10324091464281082\n",
      "0.9853515625 0.05329243093729019\n",
      "0.974609375 0.0980382040143013\n",
      "epoth: 0, iter_num: 100, loss: 0.0187, 16.00%\n",
      "0.98486328125 0.03968340530991554\n",
      "0.97802734375 0.13193301856517792\n",
      "0.96875 0.09083674848079681\n",
      "0.99462890625 0.030755499377846718\n",
      "0.96435546875 0.15113528072834015\n",
      "epoth: 0, iter_num: 200, loss: 0.0309, 32.00%\n",
      "0.98046875 0.05407458916306496\n",
      "0.9755859375 0.04200287163257599\n",
      "0.96337890625 0.10279929637908936\n",
      "0.9755859375 0.05761774629354477\n",
      "0.99560546875 0.027676289901137352\n",
      "epoth: 0, iter_num: 300, loss: 0.0450, 48.00%\n",
      "0.99072265625 0.04058162495493889\n",
      "0.99169921875 0.03472297638654709\n",
      "0.9833984375 0.0438205860555172\n",
      "0.9912109375 0.02867487072944641\n",
      "0.98486328125 0.05152392014861107\n",
      "epoth: 0, iter_num: 400, loss: 0.0556, 64.00%\n",
      "0.99267578125 0.020176555961370468\n",
      "0.9833984375 0.03304135426878929\n",
      "0.99365234375 0.021856384351849556\n",
      "0.99609375 0.025094078853726387\n",
      "0.98388671875 0.02991459146142006\n",
      "epoth: 0, iter_num: 500, loss: 0.0585, 80.00%\n",
      "0.98388671875 0.06794300675392151\n",
      "0.9970703125 0.0023129573091864586\n",
      "0.98291015625 0.047733839601278305\n",
      "0.99365234375 0.01836448535323143\n",
      "0.98876953125 0.046061575412750244\n",
      "epoth: 0, iter_num: 600, loss: 0.0329, 96.00%\n",
      "0.9892578125 0.02512132190167904\n",
      "0.97802734375 0.030762100592255592\n",
      "Epoch: 0, Average training loss: 0.0778\n",
      "Accuracy: 0.9885\n",
      "Average testing loss: 0.0397\n",
      "-------------------------------\n",
      "------------Epoch: 1 ----------------\n",
      "0.994140625 0.012883431278169155\n",
      "0.9921875 0.028787333518266678\n",
      "0.986328125 0.03391276299953461\n",
      "0.9912109375 0.03425929322838783\n",
      "0.98779296875 0.048689037561416626\n",
      "epoth: 1, iter_num: 100, loss: 0.0263, 16.00%\n",
      "0.99072265625 0.01574419066309929\n",
      "0.98291015625 0.03798341006040573\n",
      "0.9912109375 0.04045302793383598\n",
      "0.99755859375 0.015362639911472797\n",
      "0.98828125 0.02518577128648758\n",
      "epoth: 1, iter_num: 200, loss: 0.0305, 32.00%\n",
      "0.994140625 0.00803142599761486\n",
      "0.9912109375 0.02742576226592064\n",
      "0.990234375 0.03907283768057823\n",
      "0.99267578125 0.01931551843881607\n",
      "0.99072265625 0.012429164722561836\n",
      "epoth: 1, iter_num: 300, loss: 0.0139, 48.00%\n",
      "0.99609375 0.01702895201742649\n",
      "0.99072265625 0.041424695402383804\n",
      "0.990234375 0.03367319330573082\n",
      "0.98828125 0.04208147153258324\n",
      "0.990234375 0.02715323492884636\n",
      "epoth: 1, iter_num: 400, loss: 0.0317, 64.00%\n",
      "0.9892578125 0.030402926728129387\n",
      "0.9912109375 0.030082494020462036\n",
      "0.9912109375 0.0359673909842968\n",
      "0.982421875 0.03305638208985329\n",
      "0.99609375 0.013682955875992775\n",
      "epoth: 1, iter_num: 500, loss: 0.0263, 80.00%\n",
      "0.97900390625 0.052617914974689484\n",
      "0.98388671875 0.0583026185631752\n",
      "0.9775390625 0.014709308743476868\n",
      "0.9853515625 0.038522034883499146\n",
      "0.9931640625 0.020008670166134834\n",
      "epoth: 1, iter_num: 600, loss: 0.0446, 96.00%\n",
      "0.97802734375 0.07021111994981766\n",
      "0.98291015625 0.05172831192612648\n",
      "Epoch: 1, Average training loss: 0.0311\n",
      "Accuracy: 0.9885\n",
      "Average testing loss: 0.0397\n",
      "-------------------------------\n",
      "------------Epoch: 2 ----------------\n",
      "0.990234375 0.03839073330163956\n",
      "0.9873046875 0.016635265201330185\n",
      "0.98828125 0.045931074768304825\n",
      "0.98486328125 0.026105398312211037\n",
      "0.99365234375 0.019270632416009903\n",
      "epoth: 2, iter_num: 100, loss: 0.0441, 16.00%\n",
      "0.9912109375 0.05012635886669159\n",
      "0.98681640625 0.036785077303647995\n",
      "0.97802734375 0.07566747814416885\n",
      "0.9931640625 0.018448136746883392\n",
      "0.99462890625 0.013631744310259819\n",
      "epoth: 2, iter_num: 200, loss: 0.0300, 32.00%\n",
      "0.98828125 0.028544338420033455\n",
      "0.99267578125 0.022399254143238068\n",
      "0.9873046875 0.04734417051076889\n",
      "0.990234375 0.01969706267118454\n",
      "0.98583984375 0.019324446097016335\n",
      "epoth: 2, iter_num: 300, loss: 0.0204, 48.00%\n",
      "0.99365234375 0.019294846802949905\n",
      "0.9912109375 0.042378801852464676\n",
      "0.994140625 0.020513538271188736\n",
      "0.98193359375 0.05390065535902977\n",
      "0.99462890625 0.004402865655720234\n",
      "epoth: 2, iter_num: 400, loss: 0.0411, 64.00%\n",
      "0.99365234375 0.010571631602942944\n",
      "0.98486328125 0.03408976271748543\n",
      "0.98486328125 0.03624831140041351\n",
      "0.98779296875 0.015979593619704247\n",
      "0.98779296875 0.046796899288892746\n",
      "epoth: 2, iter_num: 500, loss: 0.0268, 80.00%\n",
      "0.9951171875 0.02482364885509014\n",
      "0.99365234375 0.026390183717012405\n",
      "0.9931640625 0.02190655842423439\n",
      "0.9931640625 0.014970648102462292\n",
      "0.990234375 0.02842547930777073\n",
      "epoth: 2, iter_num: 600, loss: 0.0222, 96.00%\n",
      "0.99267578125 0.023658493533730507\n",
      "0.982421875 0.015109711326658726\n",
      "Epoch: 2, Average training loss: 0.0312\n",
      "Accuracy: 0.9885\n",
      "Average testing loss: 0.0400\n",
      "-------------------------------\n",
      "------------Epoch: 3 ----------------\n",
      "0.984375 0.017112642526626587\n",
      "0.9921875 0.03326643258333206\n",
      "0.97509765625 0.06490886956453323\n",
      "0.9970703125 0.008432931266725063\n",
      "0.99169921875 0.0269821397960186\n",
      "epoth: 3, iter_num: 100, loss: 0.0286, 16.00%\n",
      "0.99267578125 0.01527893915772438\n",
      "0.97900390625 0.058531519025564194\n",
      "0.98583984375 0.04331326484680176\n",
      "0.9892578125 0.04045962914824486\n",
      "0.9912109375 0.029991429299116135\n",
      "epoth: 3, iter_num: 200, loss: 0.0099, 32.00%\n",
      "0.990234375 0.015588161535561085\n",
      "0.9921875 0.02152552641928196\n",
      "0.98876953125 0.036023806780576706\n",
      "0.99658203125 0.008520455099642277\n",
      "0.98486328125 0.03546988219022751\n",
      "epoth: 3, iter_num: 300, loss: 0.0699, 48.00%\n",
      "0.986328125 0.030712468549609184\n",
      "0.98828125 0.020864032208919525\n",
      "0.99267578125 0.00887017697095871\n",
      "0.99658203125 0.01782744750380516\n",
      "0.99755859375 0.015181810595095158\n",
      "epoth: 3, iter_num: 400, loss: 0.0831, 64.00%\n",
      "0.9912109375 0.011080128140747547\n",
      "0.98876953125 0.02567201666533947\n",
      "0.986328125 0.026015648618340492\n",
      "0.99072265625 0.011677887290716171\n",
      "0.98779296875 0.04932735115289688\n",
      "epoth: 3, iter_num: 500, loss: 0.0104, 80.00%\n",
      "0.99560546875 0.011631021276116371\n",
      "0.990234375 0.02875502221286297\n",
      "0.9912109375 0.04603218287229538\n",
      "0.978515625 0.04179846867918968\n",
      "0.9931640625 0.018458472564816475\n",
      "epoth: 3, iter_num: 600, loss: 0.0429, 96.00%\n",
      "0.9921875 0.022404665127396584\n",
      "0.99169921875 0.03245164453983307\n",
      "Epoch: 3, Average training loss: 0.0311\n",
      "Accuracy: 0.9885\n",
      "Average testing loss: 0.0397\n",
      "-------------------------------\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def train():\n",
    "    model.train()\n",
    "    total_train_loss = 0\n",
    "    iter_num = 0\n",
    "    total_iter = len(train_loader)\n",
    "    for idx, batch in enumerate(train_loader):\n",
    "        optim.zero_grad()\n",
    "        \n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = batch['labels'].to(device)\n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "            # loss = outputs[0]\n",
    "\n",
    "        loss = outputs.loss\n",
    "        \n",
    "        if idx % 20 == 0:\n",
    "            with torch.no_grad():\n",
    "                # 64 * 7\n",
    "                print((outputs[1].argmax(2).data == labels.data).float().mean().item(), loss.item())\n",
    "        \n",
    "        total_train_loss += loss.item()\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        optim.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        iter_num += 1\n",
    "        if(iter_num % 100==0):\n",
    "            print(\"epoth: %d, iter_num: %d, loss: %.4f, %.2f%%\" % (epoch, iter_num, loss.item(), iter_num/total_iter*100))\n",
    "        \n",
    "    print(\"Epoch: %d, Average training loss: %.4f\"%(epoch, total_train_loss/len(train_loader)))\n",
    "    \n",
    "def validation():\n",
    "    model.eval()\n",
    "    total_eval_accuracy = 0\n",
    "    total_eval_loss = 0\n",
    "    for batch in test_dataloader:\n",
    "        with torch.no_grad():\n",
    "            input_ids = batch['input_ids'].to(device)\n",
    "            attention_mask = batch['attention_mask'].to(device)\n",
    "            labels = batch['labels'].to(device)\n",
    "            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs.loss\n",
    "        logits = outputs[1]\n",
    "\n",
    "        total_eval_loss += loss.item()\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = labels.to('cpu').numpy()\n",
    "        total_eval_accuracy += (outputs[1].argmax(2).data == labels.data).float().mean().item()\n",
    "        \n",
    "    avg_val_accuracy = total_eval_accuracy / len(test_dataloader)\n",
    "    print(\"Accuracy: %.4f\" % (avg_val_accuracy))\n",
    "    print(\"Average testing loss: %.4f\"%(total_eval_loss/len(test_dataloader)))\n",
    "    print(\"-------------------------------\")\n",
    "    \n",
    "\n",
    "for epoch in range(4):\n",
    "    print(\"------------Epoch: %d ----------------\" % epoch)\n",
    "    train()\n",
    "    validation()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-24T02:45:11.533512Z",
     "iopub.status.busy": "2021-04-24T02:45:11.532931Z",
     "iopub.status.idle": "2021-04-24T02:45:12.163488Z",
     "shell.execute_reply": "2021-04-24T02:45:12.162999Z",
     "shell.execute_reply.started": "2021-04-24T02:45:11.533464Z"
    }
   },
   "outputs": [],
   "source": [
    "torch.save(model, 'bert-ner.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T14:51:18.819401Z",
     "start_time": "2021-04-23T14:51:18.807242Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T03:13:59.374545Z",
     "iopub.status.busy": "2021-04-24T03:13:59.374001Z",
     "iopub.status.idle": "2021-04-24T03:13:59.382786Z",
     "shell.execute_reply": "2021-04-24T03:13:59.382132Z",
     "shell.execute_reply.started": "2021-04-24T03:13:59.374497Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "tag_type = ['O', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']\n",
    "\n",
    "def predcit(s):\n",
    "    item = tokenizer([s], truncation=True, padding='longest', max_length=64)\n",
    "    with torch.no_grad():\n",
    "        input_ids = torch.tensor(item['input_ids']).to(device).reshape(1, -1)\n",
    "        attention_mask = torch.tensor(item['attention_mask']).to(device).reshape(1, -1)\n",
    "        labels = torch.tensor([0] * attention_mask.shape[1]).to(device).reshape(1, -1)\n",
    "        \n",
    "        outputs = model(input_ids, attention_mask, labels)\n",
    "        outputs = outputs[0].data.cpu().numpy()\n",
    "        \n",
    "    outputs = outputs[0].argmax(1)[1:-1]\n",
    "    ner_result = ''\n",
    "    ner_flag = ''\n",
    "    \n",
    "    for o, c in zip(outputs,s):\n",
    "        # 0 就是 O，没有含义\n",
    "        if o == 0 and ner_result == '':\n",
    "            continue\n",
    "        \n",
    "        # \n",
    "        elif o == 0 and ner_result != '':\n",
    "            if ner_flag == 'O':\n",
    "                print('机构：', ner_result)\n",
    "            if ner_flag == 'P':\n",
    "                print('人名：', ner_result)\n",
    "            if ner_flag == 'L':\n",
    "                print('位置：', ner_result)\n",
    "                \n",
    "            ner_result = ''\n",
    "        \n",
    "        elif o != 0:\n",
    "            ner_flag = tag_type[o][2]\n",
    "            ner_result += c\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T14:54:57.736608Z",
     "start_time": "2021-04-23T14:54:57.706272Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T03:13:59.848271Z",
     "iopub.status.busy": "2021-04-24T03:13:59.847730Z",
     "iopub.status.idle": "2021-04-24T03:13:59.869450Z",
     "shell.execute_reply": "2021-04-24T03:13:59.868791Z",
     "shell.execute_reply.started": "2021-04-24T03:13:59.848225Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "位置： 华盛顿\n",
      "位置： 美国\n",
      "位置： 白宫\n",
      "位置： 菲律宾\n",
      "位置： 马拉卡南宫\n"
     ]
    }
   ],
   "source": [
    "s = '整个华盛顿已笼罩在一片夜色之中，一个电话从美国总统府白宫打到了菲律宾总统府马拉卡南宫。'\n",
    "data = predcit(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T14:54:58.297655Z",
     "start_time": "2021-04-23T14:54:58.291706Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T03:07:44.697612Z",
     "iopub.status.busy": "2021-04-24T03:07:44.697074Z",
     "iopub.status.idle": "2021-04-24T03:07:44.711205Z",
     "shell.execute_reply": "2021-04-24T03:07:44.710752Z",
     "shell.execute_reply.started": "2021-04-24T03:07:44.697565Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "位置： 中国\n",
      "位置： 美国\n"
     ]
    }
   ],
   "source": [
    "s = '人工智能是未来的希望，也是中国和美国的冲突点。'\n",
    "data = predcit(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-23T14:54:42.166525Z",
     "start_time": "2021-04-23T14:54:42.134935Z"
    },
    "execution": {
     "iopub.execute_input": "2021-04-24T03:08:02.948129Z",
     "iopub.status.busy": "2021-04-24T03:08:02.947567Z",
     "iopub.status.idle": "2021-04-24T03:08:02.976157Z",
     "shell.execute_reply": "2021-04-24T03:08:02.975655Z",
     "shell.execute_reply.started": "2021-04-24T03:08:02.948085Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "位置： 海淀\n",
      "人名： 刘涛\n",
      "人名： 王华\n"
     ]
    }
   ],
   "source": [
    "s = '明天我们一起在海淀吃个饭吧，把叫刘涛和王华也叫上。'\n",
    "data = predcit(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-24T03:14:29.774924Z",
     "iopub.status.busy": "2021-04-24T03:14:29.774354Z",
     "iopub.status.idle": "2021-04-24T03:14:29.808248Z",
     "shell.execute_reply": "2021-04-24T03:14:29.807664Z",
     "shell.execute_reply.started": "2021-04-24T03:14:29.774876Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "机构： 同煤集团同生安平煤业公司\n"
     ]
    }
   ],
   "source": [
    "s = '同煤集团同生安平煤业公司发生井下安全事故 19名矿工遇难'\n",
    "data = predcit(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-24T03:14:56.992793Z",
     "iopub.status.busy": "2021-04-24T03:14:56.992229Z",
     "iopub.status.idle": "2021-04-24T03:14:57.046461Z",
     "shell.execute_reply": "2021-04-24T03:14:57.045922Z",
     "shell.execute_reply.started": "2021-04-24T03:14:56.992745Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "机构： 山东省政府办公厅\n",
      "机构： 平邑县玉荣商贸有限公司\n"
     ]
    }
   ],
   "source": [
    "s = '山东省政府办公厅就平邑县玉荣商贸有限公司石膏矿坍塌事故发出通报'\n",
    "data = predcit(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-24T03:16:04.482434Z",
     "iopub.status.busy": "2021-04-24T03:16:04.481856Z",
     "iopub.status.idle": "2021-04-24T03:16:04.514867Z",
     "shell.execute_reply": "2021-04-24T03:16:04.514221Z",
     "shell.execute_reply.started": "2021-04-24T03:16:04.482387Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "位置： 黑龙江\n",
      "机构： 龙煤集团\n"
     ]
    }
   ],
   "source": [
    "s = '[新闻直播间]黑龙江:龙煤集团一煤矿发生火灾事故'\n",
    "data = predcit(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
